{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tabular models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastai import *\n",
    "from fastai.tabular import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Tabular data should be in a Pandas `DataFrame`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = untar_data(URLs.ADULT_SAMPLE)\n",
    "df = pd.read_csv(path/'adult.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dep_var = '>=50k'\n",
    "cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']\n",
    "cont_names = ['age', 'fnlwgt', 'education-num']\n",
    "procs = [FillMissing, Categorify, Normalize]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test = TabularList.from_df(df.iloc[800:1000].copy(), path=path, cat_names=cat_names, cont_names=cont_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = (TabularList.from_df(df, path=path, cat_names=cat_names, cont_names=cont_names, procs=procs)\n",
    "                           .split_by_idx(list(range(800,1000)))\n",
    "                           .label_from_df(cols=dep_var)\n",
    "                           .add_test(test, label=0)\n",
    "                           .databunch())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table>  <col width='10%'>  <col width='10%'>  <col width='10%'>  <col width='10%'>  <col width='10%'>  <col width='10%'>  <col width='10%'>  <col width='10%'>  <col width='10%'>  <col width='10%'>  <col width='10%'>  <tr>\n",
       "    <th>workclass</th>\n",
       "    <th>education</th>\n",
       "    <th>marital-status</th>\n",
       "    <th>occupation</th>\n",
       "    <th>relationship</th>\n",
       "    <th>race</th>\n",
       "    <th>education-num_na</th>\n",
       "    <th>age</th>\n",
       "    <th>fnlwgt</th>\n",
       "    <th>education-num</th>\n",
       "    <th>target</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th> Private</th>\n",
       "    <th> 12th</th>\n",
       "    <th> Never-married</th>\n",
       "    <th> Other-service</th>\n",
       "    <th> Not-in-family</th>\n",
       "    <th> White</th>\n",
       "    <th>False</th>\n",
       "    <th>-1.2158</th>\n",
       "    <th>0.5291</th>\n",
       "    <th>-0.8135</th>\n",
       "    <th>0</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th> Private</th>\n",
       "    <th> Some-college</th>\n",
       "    <th> Never-married</th>\n",
       "    <th> Exec-managerial</th>\n",
       "    <th> Not-in-family</th>\n",
       "    <th> Black</th>\n",
       "    <th>False</th>\n",
       "    <th>-1.0692</th>\n",
       "    <th>0.0235</th>\n",
       "    <th>-0.0312</th>\n",
       "    <th>0</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th> Self-emp-not-inc</th>\n",
       "    <th> 1st-4th</th>\n",
       "    <th> Widowed</th>\n",
       "    <th> Other-service</th>\n",
       "    <th> Not-in-family</th>\n",
       "    <th> Black</th>\n",
       "    <th>False</th>\n",
       "    <th>2.0826</th>\n",
       "    <th>-0.7946</th>\n",
       "    <th>-3.1604</th>\n",
       "    <th>0</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th> Private</th>\n",
       "    <th> Prof-school</th>\n",
       "    <th> Married-civ-spouse</th>\n",
       "    <th> Prof-specialty</th>\n",
       "    <th> Husband</th>\n",
       "    <th> White</th>\n",
       "    <th>False</th>\n",
       "    <th>-0.1896</th>\n",
       "    <th>-0.3709</th>\n",
       "    <th>1.9245</th>\n",
       "    <th>0</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th> Private</th>\n",
       "    <th> HS-grad</th>\n",
       "    <th> Never-married</th>\n",
       "    <th> Handlers-cleaners</th>\n",
       "    <th> Own-child</th>\n",
       "    <th> White</th>\n",
       "    <th>False</th>\n",
       "    <th>-1.2158</th>\n",
       "    <th>0.3155</th>\n",
       "    <th>-0.4224</th>\n",
       "    <th>0</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th> ?</th>\n",
       "    <th> Some-college</th>\n",
       "    <th> Married-civ-spouse</th>\n",
       "    <th> ?</th>\n",
       "    <th> Wife</th>\n",
       "    <th> White</th>\n",
       "    <th>False</th>\n",
       "    <th>-0.8493</th>\n",
       "    <th>-0.1475</th>\n",
       "    <th>-0.0312</th>\n",
       "    <th>1</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th> Private</th>\n",
       "    <th> 11th</th>\n",
       "    <th> Divorced</th>\n",
       "    <th> Craft-repair</th>\n",
       "    <th> Own-child</th>\n",
       "    <th> White</th>\n",
       "    <th>False</th>\n",
       "    <th>-0.4828</th>\n",
       "    <th>0.1595</th>\n",
       "    <th>-1.2046</th>\n",
       "    <th>0</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th> ?</th>\n",
       "    <th> 10th</th>\n",
       "    <th> Separated</th>\n",
       "    <th>#na#</th>\n",
       "    <th> Not-in-family</th>\n",
       "    <th> White</th>\n",
       "    <th>False</th>\n",
       "    <th>-0.7760</th>\n",
       "    <th>2.2296</th>\n",
       "    <th>-1.5958</th>\n",
       "    <th>0</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th> Private</th>\n",
       "    <th> HS-grad</th>\n",
       "    <th> Married-civ-spouse</th>\n",
       "    <th> Craft-repair</th>\n",
       "    <th> Husband</th>\n",
       "    <th> White</th>\n",
       "    <th>False</th>\n",
       "    <th>0.1036</th>\n",
       "    <th>0.0909</th>\n",
       "    <th>-0.4224</th>\n",
       "    <th>0</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th> ?</th>\n",
       "    <th> HS-grad</th>\n",
       "    <th> Married-civ-spouse</th>\n",
       "    <th> ?</th>\n",
       "    <th> Husband</th>\n",
       "    <th> White</th>\n",
       "    <th>False</th>\n",
       "    <th>1.7161</th>\n",
       "    <th>-0.5733</th>\n",
       "    <th>-0.4224</th>\n",
       "    <th>0</th>\n",
       "  </tr>\n",
       "</table>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "data.show_batch(rows=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = tabular_learner(data, layers=[200,100], metrics=accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total time: 00:04\n",
      "epoch  train_loss  valid_loss  accuracy\n",
      "1      0.364682    0.385048    0.820000  (00:04)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "learn.fit(1, 1e-2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "row = df.iloc[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1, tensor(0), tensor([0.6365, 0.3635]))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "learn.predict(row)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
